import json
import multiprocessing as mp
import os
import re
import sqlite3
import sys
import time
from collections import Counter
from contextlib import closing
from functools import partial

import pandas as pd
import spacy

# Add the parent of this script's directory to the import path; the `utils`
# imports below are presumably resolved from there.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from utils.config import load_config
from utils.logger import set_logger

# NOTE(review): this is a relative path and it is passed to set_logger before
# the os.chdir at the end of this section runs, so it may be resolved against
# the caller's working directory instead of the script directory — confirm
# how set_logger resolves it.
LOG_PATH = '../../../logs/02_word_frequency.log'

# Load the configuration file and create the logger used by the whole module
config = load_config()
logger = set_logger(LOG_PATH)

# Load the Spanish model in spacy after installing it with `python -m spacy download es_core_news_sm`
# The pipeline is a module-level global that lemmatize_and_count reads.
nlp = spacy.load('es_core_news_sm')

# Setting the working directory to the directory of the script
# NOTE(review): os.path.dirname(__file__) is an empty string when __file__ has
# no directory component, in which case os.chdir would fail — confirm the
# supported ways of launching this script.
os.chdir(os.path.dirname(__file__))

# CONSTANTS
COMMENTS_TABLE = 'comments_chunk_1'

# Collection window and target country, all read from the loaded configuration
START_DATETIME = config['start_datetime']  # '2023-01-01T00:00:00Z'
END_DATETIME = config['end_datetime']  # '2023-12-31T23:59:59Z'
COUNTRY_NAME = config['country_name']  # e.g. "mexico"
COUNTRY_CODE = config['country_code']  # e.g. "mx"

# Reduce each datetime to `YYYYMM`: keep the first seven characters
# (`YYYY-MM`) and strip everything that is not a digit
_NON_DIGITS = re.compile(r'[^0-9]')
START_YEAR_MONTH = _NON_DIGITS.sub('', START_DATETIME[:7])
END_YEAR_MONTH = _NON_DIGITS.sub('', END_DATETIME[:7])

# PATHS
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
WORDNET_DIR = os.path.join(PROJECT_ROOT, '../../../data/wordnet')
WEIGHTED_DICT_PATH = os.path.join(PROJECT_ROOT, '../../../output/weighted_dictionary_es.json')
OUTPUT_DIR = os.path.join(PROJECT_ROOT, '../../../output', COUNTRY_NAME)

# The database and the CSV share the `<YYYY>-<MM>-collection` folder and the
# `<start>_<end>` file-name prefix, so both fragments are built once
_COLLECTION_DIR = f'{START_YEAR_MONTH[:4]}-{START_YEAR_MONTH[4:]}-collection'
_PERIOD_PREFIX = f'{START_YEAR_MONTH}_{END_YEAR_MONTH}'
VIDEO_DB_PATH = os.path.join(OUTPUT_DIR, f'{_COLLECTION_DIR}/{_PERIOD_PREFIX}_02_{COUNTRY_CODE}_youtube_data.db')
OUTPUT_CSV_FILE = os.path.join(OUTPUT_DIR, f'{_COLLECTION_DIR}/{_PERIOD_PREFIX}_04_{COUNTRY_CODE}_comments_count.csv')

# Function to read the weighted dictionary from a JSON file
def read_weighted_dict(file_path):
    """Load the weighted dictionary stored as JSON at ``file_path``.

    The file is read as UTF-8 text and the decoded JSON value is returned
    unchanged.
    """
    with open(file_path, mode='r', encoding='utf-8') as source:
        raw_json = source.read()
    return json.loads(raw_json)

# Function to lemmatize Spanish text and count word occurrences
def lemmatize_and_count(text, weighted_dict):
    """Count how often each dictionary lemma occurs in ``text``.

    The text is lower-cased, run through the module-level spaCy pipeline
    ``nlp`` and every token is replaced by its lemma. Only lemmas that are
    members of ``weighted_dict`` are reported.

    Args:
        text: Comment text to analyse. ``None`` or an empty string yields an
            empty result instead of raising.
        weighted_dict: Container supporting ``in`` (the loaded weighted
            dictionary) whose members are the lemmas of interest.

    Returns:
        A dict mapping each matching lemma to its number of occurrences,
        in order of first appearance in the text.
    """
    # A missing comment body has nothing to count, and None would otherwise
    # fail on .lower().
    if not text:
        return {}
    # Counter tallies all lemmas in a single pass; calling list.count for
    # every token was quadratic in the number of tokens.
    lemma_counts = Counter(token.lemma_ for token in nlp(text.lower()))
    return {
        lemma: count
        for lemma, count in lemma_counts.items()
        if lemma in weighted_dict
    }

# Function to process each record
def process_record(record, weighted_dict):
    """Build the word-count entry for one comment record.

    Args:
        record: Sequence of five values, unpacked in this order:
            comment_id, video_id, text_original, author_display_name,
            published_at.
        weighted_dict: Weighted dictionary forwarded to
            ``lemmatize_and_count``.

    Returns:
        A dict with the keys ``comment_id``, ``video_id``,
        ``author_display_name``, ``published_at`` and ``word_count`` when
        the comment contains at least one dictionary word, otherwise ``None``.
    """
    comment_id, video_id, text_original, author_display_name, published_at = record
    word_count = lemmatize_and_count(text_original, weighted_dict)
    if not word_count:  # Nothing matched: the caller filters out None results
        return None
    # The id is interpolated into the message; passing it as an extra
    # positional argument without a placeholder never rendered it.
    logger.info(f"Processing comment_id: {comment_id}")
    return {
        "comment_id": comment_id,
        "video_id": video_id,
        "author_display_name": author_display_name,
        "published_at": published_at,
        "word_count": word_count
    }

# Function to fetch records from the database
def fetch_records_from_db(db_path, table_name):
    """Read every comment row from ``table_name`` in the SQLite file ``db_path``.

    Args:
        db_path: Path to the SQLite database file.
        table_name: Name of the table to read. It is interpolated into the
            SQL text, so only trusted, internally defined names may be
            passed.

    Returns:
        A list of tuples ``(comment_id, video_id, text_original,
        author_display_name, published_at)``, one per row.

    Raises:
        SystemExit: With status 1 when ``db_path`` does not exist.
    """
    if not os.path.exists(db_path):
        logger.error(f"Database not found: {db_path}")
        sys.exit(1)
    logger.info(f"Connecting to database: {db_path}")
    # closing() guarantees the connection is closed even when the query
    # raises; sqlite3's own connection context manager only handles the
    # transaction and leaves the connection open.
    with closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()
        # SQL identifiers cannot be bound as parameters, hence the f-string.
        cursor.execute(f"SELECT comment_id, video_id, text_original, author_display_name, published_at FROM {table_name}")
        return cursor.fetchall()

def main():
    """Run the word-frequency pipeline from database to CSV.

    Loads the weighted dictionary, fetches all comment records, processes
    them in a pool with one worker per CPU, keeps the comments that matched
    at least one dictionary word and writes them to ``OUTPUT_CSV_FILE``.
    Exits with status 1 when the table holds no records.
    """
    started_at = time.time()

    # Weighted dictionary whose entries are the words to look for
    weighted_dict = read_weighted_dict(WEIGHTED_DICT_PATH)

    # Pull every comment row up front
    records = fetch_records_from_db(VIDEO_DB_PATH, COMMENTS_TABLE)
    logger.info(f"Total records fetched: {len(records)}")
    if not records:
        logger.warning("No records found in the database.")
        sys.exit(1)
    logger.info("Processing records...")

    # Bind the dictionary so the pool only has to supply the record
    worker = partial(process_record, weighted_dict=weighted_dict)

    # One worker process per CPU core
    with mp.Pool(mp.cpu_count()) as pool:
        outcomes = pool.map(worker, records)

    # Comments without any dictionary word come back as None
    matched_rows = [outcome for outcome in outcomes if outcome is not None]

    # Persist the matches as CSV
    frame = pd.DataFrame(matched_rows)
    logger.info(f"Writing to {OUTPUT_CSV_FILE}...")
    frame.to_csv(OUTPUT_CSV_FILE, index=False)
    logger.info("Done.")
    logger.info(f"Execution time: {time.time() - started_at:.2f} seconds")


if __name__ == '__main__':
    main()
